Tank classifier models:

Download images for training

# Class labels; images for each class live in a matching subdirectory.
tank_types = ('merkava mk4', 'M1 Abrams', 'water')
path = Path('tanks')

# Download up to 150 example images per class into labelled directories.
# Skipped entirely when the dataset directory already exists, so re-running
# the notebook does not re-download anything.
if not path.exists():
    path.mkdir()
    for tank_type in tank_types:
        class_dir = path/tank_type
        class_dir.mkdir(exist_ok=True)
        found_urls = search_images_ddg(f'{tank_type} tank', max_images=150)
        download_images(class_dir, urls=found_urls)

# Verify every downloaded image and delete the ones that cannot be opened
# (truncated downloads, HTML error pages saved as .jpg, etc.).
broken = verify_images(get_image_files(path))
for bad_file in broken:
    bad_file.unlink()

Building a model using multi-class loss

Preparing the data for the model

# Single-label pipeline: each image gets exactly one class, taken from its
# parent directory name; 20% of the data is held out for validation with a
# fixed seed so the split is reproducible.
resize_tfm = RandomResizedCrop(224, min_scale=0.5)
val_split = RandomSplitter(valid_pct=0.2, seed=42)
tanks = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    splitter=val_split,
    get_y=parent_label,
    item_tfms=resize_tfm,
    batch_tfms=aug_transforms())
dls = tanks.dataloaders(path)
# Sanity-check the labelling by eyeballing a few validation images.
dls.valid.show_batch(max_n=5, nrows=1)

Training the model and reviewing the results

# Transfer learning: a resnet18 pretrained on ImageNet, fine-tuned for
# 4 epochs on the tank dataset; error_rate is reported per epoch.
learn = cnn_learner(dls, resnet18, metrics=error_rate)
learn.fine_tune(4)
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
epoch train_loss valid_loss error_rate time
0 1.626117 0.992537 0.298077 00:54
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
epoch train_loss valid_loss error_rate time
0 0.471223 0.468078 0.173077 00:20
1 0.351875 0.241190 0.086538 00:20
2 0.293782 0.216738 0.067308 00:22
3 0.239018 0.217196 0.067308 00:20
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
# NOTE(review): `interp` is never defined in this file (the interpretation
# object created later is named `intrpt`), and plot_top_losses_fix is defined
# further down — these notebook cells were presumably executed in a different
# order. Confirm before running the file top-to-bottom.
plot_top_losses_fix(interp, 10, nrows=10)

Building a model using multi-label loss

Preparing the data for the model

def parentlabel(x):
    """Return the image's label as a one-element list.

    MultiCategoryBlock expects get_y to produce a *list* of labels per item
    (an image may belong to several classes), so the parent-directory name
    is wrapped in a list rather than returned bare.
    """
    return [x.parent.name]
# Multi-label pipeline: same images and split as before, but targets are
# one-hot encoded via MultiCategoryBlock so an image can carry several
# labels (or none).
resize_tfm = RandomResizedCrop(224, min_scale=0.5)
val_split = RandomSplitter(valid_pct=0.2, seed=42)
tanks = DataBlock(
    blocks=(ImageBlock, MultiCategoryBlock),
    get_items=get_image_files,
    splitter=val_split,
    get_y=parentlabel,
    item_tfms=resize_tfm,
    batch_tfms=aug_transforms())
dls = tanks.dataloaders(path)
# Quick visual check of the multi-label batches.
dls.valid.show_batch(nrows=1, ncols=5)
# Multi-label learner. accuracy_multi defaults to thresh=0.5, sigmoid=True;
# to change the decision threshold use partial(accuracy_multi, thresh=...).
multi_learner = cnn_learner(dls, resnet18, metrics=accuracy_multi)
multi_learner.fine_tune(4)
epoch train_loss valid_loss accuracy_multi time
0 0.909140 0.399930 0.852564 00:20
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
epoch train_loss valid_loss accuracy_multi time
0 0.458900 0.303905 0.855769 00:22
1 0.376303 0.216744 0.919872 00:20
2 0.296351 0.175988 0.932692 00:21
3 0.252384 0.164322 0.939103 00:21
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "

Fixing the top-losses plot

def plot_top_losses_fix(interp, k, largest=True, **kwargs):
        """Plot the `k` highest-loss items for `interp`.

        Work-around adapted from fastai's Interpretation.plot_top_losses,
        rewritten as a free function taking the interpretation object
        explicitly (presumably because the method raised for the multi-label
        case — TODO confirm). Extra `**kwargs` (e.g. nrows) are forwarded to
        the plotting call.
        """
        # Indices (and losses) of the k largest (or smallest) losses.
        losses,idx = interp.top_losses(k, largest)
        # Normalise inputs to a tuple so the per-tensor indexing below works.
        if not isinstance(interp.inputs, tuple): interp.inputs = (interp.inputs,)
        # Tensor inputs can be fancy-indexed directly; otherwise rebuild a
        # batch from the selected raw items via the dataloader.
        if isinstance(interp.inputs[0], Tensor): inps = tuple(o[idx] for o in interp.inputs)
        else: inps = interp.dl.create_batch(interp.dl.before_batch([tuple(o[i] for o in interp.inputs) for i in idx]))
        # Inputs + ground-truth targets, decoded for display.
        b = inps + tuple(o[idx] for o in (interp.targs if is_listy(interp.targs) else (interp.targs,)))
        x,y,its = interp.dl._pre_show_batch(b, max_n=k)
        # Inputs + model's decoded predictions, decoded the same way.
        b_out = inps + tuple(o[idx] for o in (interp.decoded if is_listy(interp.decoded) else (interp.decoded,)))
        x1,y1,outs = interp.dl._pre_show_batch(b_out, max_n=k)
        if its is not None:
            #plot_top_losses(x, y, its, outs.itemgot(slice(len(inps), None)), L(self.preds).itemgot(idx), losses,  **kwargs)
            # Uses interp.preds[idx] instead of the original self.preds lookup.
            plot_top_losses(x, y, its, outs.itemgot(slice(len(inps), None)), interp.preds[idx], losses,  **kwargs)
        #TODO: figure out if this is needed
        #its None means that a batch knows how to show itself as a whole, so we pass x, x1
        #else: show_results(x, x1, its, ctxs=ctxs, max_n=max_n, **kwargs)
# Build an interpretation object (predictions, targets, losses) for the
# multi-label learner's validation set.
intrpt = ClassificationInterpretation.from_learner(multi_learner)
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
# Show the 10 validation items with the highest loss, one per row.
plot_top_losses_fix(intrpt, 10, nrows=10)
target predicted probabilities loss
0 merkava mk4 M1 Abrams TensorBase([9.9981e-01, 7.1941e-02, 6.6779e-04]) 3.738661766052246
1 M1 Abrams merkava mk4 TensorBase([0.4478, 0.9315, 0.0175]) 1.1671565771102905
2 merkava mk4 M1 Abrams TensorBase([0.8890, 0.3600, 0.0038]) 1.0746452808380127
3 merkava mk4 merkava mk4;water TensorBase([0.0189, 0.8435, 0.9244]) 0.923978328704834
4 M1 Abrams M1 Abrams;merkava mk4 TensorBase([0.6619, 0.8650, 0.0027]) 0.806043267250061
5 M1 Abrams M1 Abrams;merkava mk4 TensorBase([0.7605, 0.8781, 0.0056]) 0.794671893119812
6 M1 Abrams M1 Abrams;merkava mk4 TensorBase([0.9500, 0.8916, 0.0201]) 0.7645870447158813
7 M1 Abrams M1 Abrams;merkava mk4 TensorBase([0.9330, 0.8667, 0.0278]) 0.7042757868766785
8 merkava mk4 M1 Abrams;merkava mk4 TensorBase([0.8400, 0.9181, 0.0014]) 0.6398396492004395
9 merkava mk4 M1 Abrams;merkava mk4 TensorBase([0.7792, 0.7856, 0.0041]) 0.5853219628334045
# Re-evaluate on the validation set with a stricter decision threshold of
# 0.8 (a label counts as predicted only when its probability exceeds 0.8).
multi_learner.metrics = partial(accuracy_multi, thresh=0.8)
multi_learner.validate()
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
(#2) [0.15455161035060883,0.9615384340286255]
# Sweep the decision threshold and plot validation accuracy at each value.
# sigmoid=False because get_preds is expected to return post-sigmoid
# probabilities already — TODO confirm against the fastai version in use.
preds, targs = multi_learner.get_preds()
thresholds = torch.linspace(0.05, 0.95, 29)
accuracies = [accuracy_multi(preds, targs, thresh=t, sigmoid=False)
              for t in thresholds]
plt.plot(thresholds, accuracies);
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
# Lowering loss_func.thresh changes how decoded predictions are binarized.
# NOTE(review): the recorded output below (0.5) predates this assignment —
# re-running the cell would print 0.2.
multi_learner.loss_func.thresh=0.2
multi_learner.loss_func.thresh
0.5
# Restore the stricter threshold, then serialize the whole learner
# (model + dataloader transforms) for inference.
multi_learner.loss_func.thresh=0.8
multi_learner.export('mlmodel.pkl')

The same model, but with an explicitly assigned loss function

# Same architecture, but the loss is assigned explicitly instead of letting
# fastai infer it from the MultiCategoryBlock.
# NOTE(review): this is raw PyTorch nn.BCEWithLogitsLoss rather than fastai's
# BCEWithLogitsLossFlat wrapper — confirm decoding/thresholding still works.
ls=nn.BCEWithLogitsLoss()
multi_learner2 = cnn_learner(dls, resnet18, loss_func=ls, metrics=accuracy_multi)
multi_learner2.fine_tune(4)
multi_learner2.get_preds(with_decoded=False)

Building a multi-label classifier on the same data, using our own loss function

# Alternative: fastai's flattened wrapper, loss_func=BCEWithLogitsLossFlat(thresh=0.7)
# NOTE(review): `lr` is a confusing name for a Learner — it reads as
# "learning rate"; consider renaming if this cell is reused.
lr = cnn_learner(dls, resnet18, loss_func=F.binary_cross_entropy_with_logits, metrics=partial(accuracy_multi, thresh=0.2) )

Simple UI demo

To use our model in an application, we can simply treat the predict method as a regular function.

# path = Path()
# Load the exported learner for inference. NOTE: the only export performed in
# this file is multi_learner.export('mlmodel.pkl'), so that is the file loaded
# here — 'export.pkl' is never created in this notebook.
learn_inf = load_learner(path/'mlmodel.pkl')
# Minimal ipywidgets UI: upload button, output area for the thumbnail,
# label for the prediction text, and a button that triggers classification.
btn_upload = widgets.FileUpload()
out_pl = widgets.Output()
lbl_pred = widgets.Label()
btn_run = widgets.Button(description='Classify')
def on_click_classify(change):
    """Classify the most recently uploaded image and display the result.

    Reads the last upload from btn_upload, shows a 128px thumbnail in the
    output area, runs the loaded learner, and writes the prediction and its
    probability into the label widget.
    """
    uploaded = PILImage.create(btn_upload.data[-1])
    out_pl.clear_output()
    with out_pl:
        display(uploaded.to_thumb(128,128))
    pred,pred_idx,probs = learn_inf.predict(uploaded)
    lbl_pred.value = f'Prediction: {pred}; Probability: {probs[pred_idx]:.04f}'

# Wire the button to the handler and lay the widgets out vertically.
btn_run.on_click(on_click_classify)
VBox([widgets.Label('Select your tank!'), 
      btn_upload, btn_run, out_pl, lbl_pred])

Additional validation

# Spot-check both learners on hand-picked images (unrelated pictures,
# multi-tank scenes, and images containing both tank types).
multi_learner2.predict('/content/DUDU.jpg')
multi_learner2.predict('/content/TetroWtank.jpg')
multi_learner.predict('/content/manytanks.jpg')
multi_learner.predict('/content/merk_abrms2.jpeg')
multi_learner.predict('/content/merkav_abrams.jpg')
multi_learner.predict('/content/mek_abs_3.jpg')
 
#img.show()
#learn.predict(img)[0]